""" Test a trained Tensorforce agent on the plant simulator (RLModelHelper) and save the per-episode results. """
import argparse
import os

import copy
import numpy as np
from tensorforce.agents import Agent
from tensorforce.environments import Environment
import matplotlib.pyplot as plt
import tensorflow as tf

from model.env_rl_helper import RLModelHelper
from utils.utils_helper import UtilsHelper

# --- Command-line arguments ---------------------------------------------------
# --seed    : seeds NumPy's global RNG and TensorFlow, and selects both the
#             trained-agent directory (<results dir>/seed_<seed>_train) and the
#             output directory of this test run.
# --respath : 0 -> agents stored under 'results' (offline agents),
#             1 -> agents stored under 'results_online_implementation'.
parser = argparse.ArgumentParser(description="run_file")
parser.add_argument(
    '--seed', default=0, type=int,
    help='random seed for NumPy/TensorFlow; also selects which trained agent '
         'is loaded (default: 0)',
)
parser.add_argument(
    '--respath', default=0, type=int,
    help="results directory: 0 = 'results' (offline agents), "
         "1 = 'results_online_implementation' (online agents) (default: 0)",
)
args = parser.parse_args()
# Seed the global RNGs so runs with the same --seed start from the same RNG state.
np.random.seed(args.seed)
tf.random.set_seed(args.seed)

# --- The experiment: load the environment and the trained RL agent ----------
# --respath selects the results tree that the trained agent is read from; the
# test results are written under the same tree.
RESULT_DIRS = {
    0: 'results',                        # Test offline agents
    1: 'results_online_implementation',  # Test online implemented agents
}
if args.respath not in RESULT_DIRS:
    # SystemExit(msg) prints msg to stderr and exits with status 1. quit() is
    # only defined when the `site` module is loaded and exited with status 0.
    raise SystemExit('The argument for the result path is wrong. Please define it carefully')
path = RESULT_DIRS[args.respath]

agent_directory = os.path.join(path, 'seed_{}_train'.format(args.seed))
# initial_states = np.load('data/initial_states_10k_episodes_smallCIS_seed20.npy')

env = RLModelHelper()  # For plant simulator
agent = Agent.load(directory=agent_directory, filename='ppo_cstr_0p0001lr_10000episodes')  # Load trained agent

# --- Initialize simulation ----------------------------------------------------
score_history = []                  # per-episode score (violating steps scored with the penalty)
rewards_history = []                # per-episode list of raw rewards returned by env.execute
actions_history = []                # per-episode list of actions['Tc']
states_history = []                 # per-episode list: initial state + state returned after each step
fail_history = []                   # per-episode number of constraint-violation rollbacks
bad_state_transitions_history = []  # per-episode transition log if the episode had a rollback, else []
n_episodes = 10000
num_sim = 200                       # steps simulated per episode (no early stop on `terminal`)
bad_epi = 0                         # episodes with at least one rollback
VIOLATION_PENALTY = -1000           # score of a violating step, used instead of its reward

# --- Test the RL agent --------------------------------------------------------
for episode in range(n_episodes):
    env.actual_reset = True
    states = env.reset()
    # states = initial_states[episode,:]
    # env.current_state = initial_states[episode,:]
    rewards_episode = []
    actions_episode = []
    states_episode = [states]
    bad_state_transitions_episode = []
    score = 0
    reset_x = 0  # How many times the state was rolled back in this episode.
    internals = agent.initial_internals()
    for _ in range(num_sim):
        previous_state = copy.deepcopy(states)
        actions, internals = agent.act(states=states, internals=internals, independent=True)
        states, _terminal, reward = env.execute(actions=actions)
        rewards_episode.append(reward)
        actions_episode.append(actions['Tc'])
        states_episode.append(states)
        # Every transition is logged. The log is kept only if the episode had a rollback.
        bad_state_transitions_episode.append([env.timestep, previous_state, actions['Tc'], reward, states])
        # The new state lies outside {x : hrepABig @ x <= hrepBBig}. This also
        # covers states violating the physical constraint. Score the step with
        # the penalty instead of its reward, and roll both the local state and
        # the simulator back to the previous state.
        if (np.matmul(env.hrepABig, states) > env.hrepBBig).any():
            score += VIOLATION_PENALTY
            states = copy.deepcopy(previous_state)
            env.current_state = copy.deepcopy(previous_state)
            reset_x += 1
        else:
            score += reward

    if reset_x:  # Resetting x happened
        bad_epi += 1
        bad_state_transitions_history.append(bad_state_transitions_episode)
    else:
        bad_state_transitions_history.append([])
    print(bad_epi, 'out of', episode + 1)
    score_history.append(score)
    fail_history.append(reset_x)
    rewards_history.append(rewards_episode)
    actions_history.append(actions_episode)
    states_history.append(states_episode)
# n_episodes (not the loop variable) so this also works when no episode ran.
print('Bad episode results (test):', bad_epi, 'out of', n_episodes)
agent.close()
env.close()

# --- Save the test results ----------------------------------------------------
# Only the evaluation trajectories are written here; the agent itself is not saved.
fname = agent.spec['agent'] + '_cstr_' + '0p0001' + 'lr_' + str(n_episodes) + 'episodes'
result_dir = os.path.join(path, 'seed_{}_agent_test_on_seed_{}'.format(args.seed, args.seed))
# makedirs also creates a missing parent directory and does not fail when the
# directory already exists from an earlier run. With os.mkdir, a failure here
# made every np.save below fail.
os.makedirs(result_dir, exist_ok=True)


def _to_object_array(items):
    """Return a 1-D object array whose i-th element is ``items[i]``.

    ``np.save`` on a ragged nested list raises ``ValueError`` on NumPy >= 1.24.
    An example is ``[]`` for some episodes and a list of transitions for
    others. Filling the array element by element bypasses NumPy's shape
    inference. Read the file back with ``np.load(..., allow_pickle=True)``.
    """
    out = np.empty(len(items), dtype=object)
    for i, item in enumerate(items):
        out[i] = item
    return out


scores_file = os.path.join(result_dir, 'scores_' + fname)
rewards_file = os.path.join(result_dir, 'rewards_' + fname)
actions_file = os.path.join(result_dir, 'actions_' + fname)
states_file = os.path.join(result_dir, 'states_' + fname)
fail_file = os.path.join(result_dir, 'fail_' + fname)
bad_states_file = os.path.join(result_dir, 'bad_states_' + fname)
np.save(scores_file, score_history)
np.save(rewards_file, rewards_history)
np.save(actions_file, actions_history)
np.save(states_file, states_history)
np.save(fail_file, fail_history)  # per-episode rollback counts (previously never saved)
np.save(bad_states_file, _to_object_array(bad_state_transitions_history))
